library(tidyverse)
library(MatchIt)
library(estimatr)
library(lmtest)
library(sandwich)

# Blocked difference-in-means estimate shared by the "as randomized" and
# "as treated" analyses.
#
# `df` must contain the columns `date_batch` (block id), `Z` (0/1 assignment)
# and the outcome column named by `variable_str`. Blocks with fewer than two
# treated or fewer than two control rows are dropped, because a within-block
# variance cannot be computed for them. Returns a named list with the same
# fields as the tibble produced by `get_treatment_effect()`.
blocked_difference_in_means <- function(df, variable_str) {
  df_ate <- df %>%
    group_by(date_batch) %>%
    summarize(
      n_0 = n() - sum(Z, na.rm = TRUE),
      n_1 = sum(Z, na.rm = TRUE),
      n_j = n(),
      yb_diff =
        mean(.data[[variable_str]][Z == 1], na.rm = TRUE) -
        mean(.data[[variable_str]][Z == 0], na.rm = TRUE),
      s2_0 = var(.data[[variable_str]][Z == 0], na.rm = TRUE),
      s2_1 = var(.data[[variable_str]][Z == 1], na.rm = TRUE),
      # per-block sampling variance of the difference in means
      se_j = s2_1 / n_1 + s2_0 / n_0
    ) %>%
    filter(n_0 >= 2 & n_1 >= 2)

  # Group means are computed on exactly the blocks that enter the estimate,
  # so the block filter is derived from df_ate rather than recomputed.
  df_kept <- df %>%
    filter(date_batch %in% df_ate$date_batch)
  outcome <- df_kept[[variable_str]]
  mu_0 <- mean(outcome[which(df_kept$Z == 0)], na.rm = TRUE)
  mu_1 <- mean(outcome[which(df_kept$Z == 1)], na.rm = TRUE)

  # weights for each block: share of all retained rows
  w <- df_ate$n_j / sum(df_ate$n_j, na.rm = TRUE)

  # weighted treatment effect and its standard error
  ate_hat <- sum(w * df_ate$yb_diff)
  se_hat <- sqrt(sum(w^2 * df_ate$se_j))

  # two-sided p-value from the normal approximation (uses unrounded values)
  z_hat <- abs(ate_hat / se_hat)
  p_val <- 2 * pnorm(-z_hat)

  list(
    n_1 = sum(df_ate$n_1, na.rm = TRUE),
    n_0 = sum(df_ate$n_0, na.rm = TRUE),
    mu_1 = mu_1,
    mu_0 = mu_0,
    ate_hat = round(ate_hat, digits = 2),
    se_hat = round(se_hat, digits = 2),
    p_val = p_val
  )
}

#' Estimate a treatment effect, its standard error and p-value
#'
#' Supports three analyses, selected by `analysis_type`:
#' * `"as randomized"`: blocked difference in means by `Z` within
#'   `date_batch`, using every row of `df`.
#' * `"as treated"`: the same estimator restricted to rows where `Z` and `D`
#'   agree (`Z == 0 & D == 0` or `Z == 1 & D == 1`).
#' * `"as matched"`: `df` is passed to `match.data()`; the outcome is
#'   regressed on `D` with standard errors clustered on `date_batch`.
#'
#' @param df A data frame with columns `date_batch`, `Z`, `D` and the outcome
#'   (for the first two analyses), or an object accepted by `match.data()`
#'   (for `"as matched"`).
#' @param analysis_type One of `"as randomized"`, `"as treated"`,
#'   `"as matched"`.
#' @param variable_str Name of the outcome column, as a string.
#' @param control_variables Currently unused; kept so existing callers that
#'   pass it keep working.
#' @return A one-row tibble with columns `n_1`, `n_0`, `mu_1`, `mu_0`,
#'   `ate_hat`, `se_hat` (both rounded to 2 digits) and `p_val`; `NULL` when
#'   `analysis_type` is not one of the supported values.
get_treatment_effect <- function(
  df, analysis_type, variable_str, control_variables = NULL
) {
  if (analysis_type == "as randomized") {
    est <- blocked_difference_in_means(df, variable_str)
  } else if (analysis_type == "as treated") {
    # keep only units whose received treatment matches their assignment
    df_treated <- df %>%
      filter((Z == 0 & D == 0) | (Z == 1 & D == 1))
    est <- blocked_difference_in_means(df_treated, variable_str)
  } else if (analysis_type == "as matched") {
    df_ate <- match.data(df)
    lm_fit <- lm(as.formula(paste0(variable_str, " ~ D")), data = df_ate)
    coefs <- coeftest(lm_fit, vcov. = vcovCL, cluster = ~ date_batch)
    n <- nobs(lm_fit)

    # Row 2 of the coefficient table is the `D` term; columns are
    # estimate, standard error, test statistic, p-value.
    # NOTE(review): n / 2 assumes 1:1 matching with no missing outcomes,
    # and the group means below split on Z while the regression uses D;
    # confirm both are intended.
    outcome <- df_ate[[variable_str]]
    est <- list(
      n_1 = n / 2,
      n_0 = n / 2,
      mu_1 = mean(outcome[which(df_ate$Z == 1)], na.rm = TRUE),
      mu_0 = mean(outcome[which(df_ate$Z == 0)], na.rm = TRUE),
      ate_hat = round(coefs[2, 1], digits = 2),
      se_hat = round(coefs[2, 2], digits = 2),
      p_val = coefs[2, 4]
    )
  } else {
    return(NULL)
  }

  tibble(
    n_1 = est$n_1,
    n_0 = est$n_0,
    mu_1 = est$mu_1,
    mu_0 = est$mu_0,
    ate_hat = est$ate_hat,
    se_hat = est$se_hat,
    p_val = est$p_val
  )
}
